import paddle
import paddle.nn.functional as F
from paddle.vision.models.densenet import ConvBNLayer


class BottleneckBlock(paddle.nn.layer):
    def __init__(self,
                 num_channels,
                 num_filters,
                 stride,
                 shortcut=True):
        super(BottleneckBlock,self).__init__()

        # 创建第一个卷积层 1x1
        self.conv0 = ConvBNLayer(num_channels=num_channels,
                                 num_filters=num_filters,
                                 filter_size=1,
                                 act='relu')

        # 创建第二个卷积层 3x3
        self.conv1 = ConvBNLayer(num_channels=num_filters,
                                 num_filters=num_filters,
                                 filter_size=3,
                                 stride=stride,
                                 act='relu')

        # 创建第三个卷积 1x1,但输出通道数乘以4
        self.conv2 = ConvBNLayer(num_channels=num_filters,
                                 num_filters=num_filters * 4,
                                 filter_size=1,
                                 act=None)

        # 如果conv2的输出跟此残差块的输入数据形状一致，则shortcut=True
        # 否则shortcut = False，添加1个1x1的卷积作用在输入数据上，使其形状变成跟conv2一致
        if not shortcut:
            self.short = ConvBNLayer(num_channels = num_channels,
                                     num_filters=num_filters * 4,
                                     filter_size=1,
                                     stride=stride)

        self.shortcut = shortcut

        self._num_channels_out = num_filters * 4

    def forward(self,inputs):
        y = self.conv0(inputs)
        conv1 = self.conv1(y)
        conv2 = self.conv2(conv1)

        # 如果shortcut=True，直接将input跟conv2的输出相加
        # 否则需要对inputs进行一次卷积，将形状调整成跟conv2输出一致
        if self.shortcut:
            short = inputs
        else:
            short = self.short(inputs)

        y = paddle.add(x=short, y=conv2)
        y = F.relu(y)
        return y